
import numpy as np
import torch
from torch import nn
import device
import sys

from torchsummary import summary

# Module-level configuration flags. None of them is referenced by the code
# visible in this module; presumably they are read by importing scripts or are
# leftovers from earlier experiments — TODO confirm before removing.
finalFc = False
initFc = False
initUnit = True

class FwdBwdNeuralEq(nn.Module):
	"""Grid-structured neural equalizer with a forward and a backward sweep.

	The network is a ``row`` x ``col`` grid of small MLP units (``nnUnit``):

	* Row 0: unit ``i`` maps ``x[:, i]`` to ``depth`` features
	  (``Linear(1, depth)`` -> ``Tanh``).
	* Rows ``k > 0`` (``Linear -> Tanh -> Linear -> Tanh``, ``depth`` outputs):
	  column 0 sees only the unit above it, ``a[k-1, 0]``; columns
	  ``1..main`` also see their left neighbour ``a[k, i-1]`` (left-to-right
	  sweep); columns ``main+1..col-2`` also see their right neighbour
	  ``a[k, i+1]`` (right-to-left sweep); column ``col-1`` sees only the unit
	  above it, ``a[k-1, col-1]``.

	``nnFinalUnit`` combines the last-row activations of columns ``main`` and
	``main+1`` into ``modNum`` outputs (2 / 4 / 8 for 'nrz' / 'pam4' / 'pam8').

	Activations live in ``self.a`` (shape ``(row, col, batchSize, depth)``) and
	are written in place on every forward call, so that tensor keeps a reference
	to the last autograd graph; call :meth:`detachA` before the next
	forward/backward step (as the ``__main__`` demo does).
	"""

	def __init__(self,row,col,main,depth,batchSize,mod):
		"""Build the unit grid, the output head and the activation buffers.

		Args:
			row: number of grid rows (the original note asks for row >= 2).
			col: number of grid columns; row-0 unit ``i`` consumes ``x[:, i]``.
			main: column of the main tap. Must satisfy ``0 <= main < col - 1``
				because the head reads columns ``main`` and ``main + 1``.
			depth: hidden width of every unit.
			batchSize: batch size the activation buffers are allocated for.
			mod: modulation name, one of 'nrz', 'pam4', 'pam8'.

		Raises:
			ValueError: if ``mod`` is unknown or ``main`` is out of range.
		"""
		super().__init__()

		# row must be greater or equal than 2
		self.batchSize = batchSize
		self.main = main
		self.row = row
		self.col = col
		self.mod = mod
		self.depth = depth

		modNums = {'nrz': 2, 'pam4': 4, 'pam8': 8}
		if mod not in modNums:
			raise ValueError(f"unsupported modulation {mod!r}; expected one of {sorted(modNums)}")
		self.modNum = modNums[mod]
		if not 0 <= main < col - 1:
			raise ValueError(f"main must satisfy 0 <= main < col - 1 (got main={main}, col={col})")

		# nnUnit[k][i] is the unit at grid row k, column i. Units are created
		# row by row, left to right (same order as before, so seeded parameter
		# initialisation is unchanged).
		self.nnUnit = [[self._makeUnit(k, i) for i in range(self.col)] for k in range(self.row)]

		# Output head; created on the same device as the grid units so forward
		# does not mix devices when the model has not been moved explicitly.
		self.nnFinalUnit = nn.Sequential(
			nn.Linear(self.depth*2, self.depth, device=device.device),
			nn.Tanh(),
			nn.Linear(self.depth, self.modNum, device=device.device),
		)

		# Register every grid unit as a submodule (row-major, so parameter names
		# stay 'nn.<k*col+i>.<layer>.*'). A comprehension is used instead of
		# np.array(self.nnUnit).flatten(): NumPy treats nn.Sequential as a
		# sequence, so the ragged grid (2 layers in row 0, 4 elsewhere) raises
		# ValueError on NumPy >= 1.24.
		self.nn = nn.ModuleList([unit for rowUnits in self.nnUnit for unit in rowUnits])

		self.a = torch.zeros(self.row, self.col, batchSize,self.depth, device=device.device)
		self.aFinal= torch.zeros(batchSize,1,device=device.device)

	def _hiddenInFeatures(self, i):
		"""Return the input width of a row k>0 unit in column ``i``.

		Edge columns (0 and col-1) only see the unit above them (``depth``
		features); every other column also sees a horizontal neighbour and so
		receives ``2 * depth`` concatenated features.
		"""
		return self.depth if i in (0, self.col - 1) else 2 * self.depth

	def _makeUnit(self, k, i):
		"""Create the MLP unit for grid position (k, i)."""
		if k == 0:
			return nn.Sequential(nn.Linear(1, self.depth, device=device.device), nn.Tanh())
		return nn.Sequential(
			nn.Linear(self._hiddenInFeatures(i), self.depth, device=device.device),
			nn.Tanh(),
			nn.Linear(self.depth, self.depth, device=device.device),
			nn.Tanh(),
		)

	def forward(self, x):
		"""Run the left-to-right and right-to-left sweeps, then the head.

		Args:
			x: input samples. If its last dimension is not 1 it is reshaped to
				``(batchSize, -1, 1)``; row-0 unit ``i`` then reads ``x[:, i]``.
				Presumably ``(batchSize, col, 1)`` — TODO confirm with callers.

		Returns:
			The ``nnFinalUnit`` output (also stored in ``self.aFinal``), shape
			``(batchSize, modNum)``.
		"""
		self.a = self.a.to(device.device)
		self.aFinal = self.aFinal.to(device.device)
		if x.shape[-1] != 1:
			x = x.reshape(self.batchSize, -1, 1)

		# Left-to-right sweep over columns 0..main of every row. Reads from
		# self.a are clone()d, presumably so the in-place writes into self.a do
		# not trip autograd's in-place modification checks.
		for k in range(self.row):
			for i in range(self.col):
				if k == 0:
					if i <= self.main:
						self.a[k, i] = self.nnUnit[k][i](x[:, i])
				elif i == 0:
					self.a[k, i] = self.nnUnit[k][i](self.a[k-1, i].clone())
				elif i <= self.main:
					left = self.a[k, i-1].clone()
					above = self.a[k-1, i].clone()
					self.a[k, i] = self.nnUnit[k][i](torch.cat((left, above), -1))

		# Right-to-left sweep over columns col-1 down to main+1 of every row.
		for k in range(self.row):
			for i in reversed(range(self.col)):
				if k == 0:
					if i > self.main:
						self.a[k, i] = self.nnUnit[k][i](x[:, i])
				elif i == self.col - 1:
					self.a[k, i] = self.nnUnit[k][i](self.a[k-1, i].clone())
				elif i > self.main:
					right = self.a[k, i+1].clone()
					above = self.a[k-1, i].clone()
					self.a[k, i] = self.nnUnit[k][i](torch.cat((right, above), -1))

		# The head sees the two columns where the sweeps meet.
		last = self.row - 1
		self.aFinal = self.nnFinalUnit(
			torch.cat((self.a[last][self.main].clone(), self.a[last][self.main+1].clone()), -1)
		)
		return self.aFinal

	def print(self):
		"""Print the activation buffer ``self.a``."""
		print(self.a)

	def detachA(self):
		"""Detach the activation buffers from the current autograd graph."""
		self.a = self.a.detach()
		self.aFinal = self.aFinal.detach()

	def measureModuleSparsity(self, module, weight=True, bias=False, useMask=False):
		"""Count zero entries in a module's weights and/or biases.

		Args:
			module: module to inspect (``self`` is not used).
			weight: include tensors whose name contains the weight key.
			bias: include tensors whose name contains the bias key.
			useMask: inspect pruning mask buffers ('weight_mask' / 'bias_mask')
				instead of the parameters ('weight' / 'bias').

		Returns:
			``(numZeros, numElements, numZeros / numElements)``.

		Raises:
			ZeroDivisionError: if no tensor matched (e.g. ``useMask=True`` on a
				module that has no mask buffers).
		"""
		if useMask:
			tensors = module.named_buffers()
			weightKey, biasKey = "weight_mask", "bias_mask"
		else:
			tensors = module.named_parameters()
			weightKey, biasKey = "weight", "bias"

		numZeros = 0
		numElements = 0
		for name, tensor in tensors:
			# Two independent checks, as a name could in principle match both.
			if weight and weightKey in name:
				numZeros += torch.sum(tensor == 0).item()
				numElements += tensor.nelement()
			if bias and biasKey in name:
				numZeros += torch.sum(tensor == 0).item()
				numElements += tensor.nelement()

		sparsity = numZeros / numElements
		return numZeros, numElements, sparsity

	def measureNnSeqSparsity(self, weight=True, bias=False, useMask=True):
		'''
		Measure sparsity of every nn.Linear layer in the grid and in the head.

		Execute this only after pruning when useMask=True; otherwise there are
		no mask buffers and measureModuleSparsity divides by zero.

		Returns:
			(numZerosList, numElementsList, sparsityList,
			 numZerosFinalList, numElementsFinalList, sparsityFinalList).
			The first three are row x col nested lists aggregated over the Linear
			layers of each grid unit (0.0 where a unit has no Linear layer); the
			last three hold one entry per Linear layer of nnFinalUnit.
		'''
		numZerosList = [[0.0] * self.col for _ in range(self.row)]
		numElementsList = [[0.0] * self.col for _ in range(self.row)]
		sparsityList = [[0.0] * self.col for _ in range(self.row)]
		for k in range(self.row):
			for i in range(self.col):
				counts = [
					self.measureModuleSparsity(layer, weight=weight, bias=bias, useMask=useMask)[:2]
					for layer in self.nnUnit[k][i]
					if isinstance(layer, torch.nn.Linear)
				]
				if not counts:
					continue  # no Linear layer: keep the 0.0 placeholders
				zeros = sum(z for z, _ in counts)
				elements = sum(e for _, e in counts)
				numZerosList[k][i] = zeros
				numElementsList[k][i] = elements
				sparsityList[k][i] = zeros / elements

		numZerosFinalList = []
		numElementsFinalList = []
		sparsityFinalList = []
		for layer in self.nnFinalUnit:
			if isinstance(layer, torch.nn.Linear):
				numZeros, numElements, sparsity = self.measureModuleSparsity(
					layer, weight=weight, bias=bias, useMask=useMask
				)
				numZerosFinalList.append(numZeros)
				numElementsFinalList.append(numElements)
				sparsityFinalList.append(sparsity)
		return numZerosList, numElementsList, sparsityList, numZerosFinalList, numElementsFinalList, sparsityFinalList
	
	


if __name__ == "__main__":
	# Smoke test: build a small equalizer, show a torchsummary table and run two
	# Adam steps on constant inputs.
	batchSize = 8
	x = torch.ones(batchSize,10,1, device=device.device)
	nEQ = FwdBwdNeuralEq(2,10,5,2,batchSize,'nrz')
	# Put every registered parameter on the same device as the inputs and the
	# activation buffers before building the optimizer.
	nEQ = nEQ.to(device.device)
	opt = torch.optim.Adam(nEQ.parameters())
	# Grid units are registered in the `nn` ModuleList as nn.Sequential objects,
	# so the first Linear weight of unit (0, 0) is named 'nn.0.0.weight'.
	for name, param in nEQ.named_parameters():
		if name == 'nn.0.0.weight':
			print(f"name:{name} params:\n{param}")

	print(x.shape)
	summary(nEQ, (8,10), batch_size=batchSize, device=device.device)
	# summary() presumably runs its own forward pass, which writes into nEQ.a in
	# place; detach so the first backward below does not also walk that graph.
	nEQ.detachA()

	pred = nEQ(x)
	loss = pred.sum()
	print(f"pred:{pred}")
	print(f"loss:{loss}")
	opt.zero_grad()
	loss.backward()
	opt.step()
	print('hi')

	# nEQ.a still references the first step's (already back-propagated) graph,
	# so detach it before the second forward/backward.
	nEQ.detachA()
	y = torch.zeros(batchSize,10,1, device=device.device)
	pred = nEQ(y)
	loss = pred.sum()
	print(f"pred:{pred}")
	print(f"loss:{loss}")
	opt.zero_grad()
	loss.backward()
	opt.step()
